{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "becd8cf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://huggingface.co/datasets/mesolitica/Malaysian-SFT/resolve/main/combine/combined-malaysian-sft-5k-sample.jsonl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "251334c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89ffb39b77fd4f17ad6f28c8ead65638",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'/home/mesolitica/stt/Malaysian-Instructions'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "snapshot_download(\n",
    "    repo_id=\"mesolitica/Malaysian-Instructions\",\n",
    "    repo_type=\"dataset\",\n",
    "    allow_patterns=[\n",
    "        'data/longer*.parquet',\n",
    "        'data/*manglish*.parquet',\n",
    "        'data/voice*.parquet',\n",
    "    ],\n",
    "    local_dir=\"./Malaysian-Instructions\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a9291668",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    }
   ],
   "source": [
    "import librosa\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pandas as pd\n",
    "from datasets import Audio\n",
    "from transformers import AutoTokenizer\n",
    "from streaming import MDSWriter\n",
    "from streaming.base.format.mds.encodings import Encoding, _encodings\n",
    "from streaming import LocalDataset\n",
    "import streaming\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from glob import glob\n",
    "import os\n",
    "import json\n",
    "from multiprocess import Pool\n",
    "import itertools\n",
    "\n",
    "def chunks(l, n):\n",
    "    for i in range(0, len(l), n):\n",
    "        yield (l[i: i + n], i // n)\n",
    "\n",
    "def multiprocessing(strings, function, cores=6, returned=True):\n",
    "    df_split = chunks(strings, len(strings) // cores)\n",
    "    pool = Pool(cores)\n",
    "    pooled = pool.map(function, df_split)\n",
    "    pool.close()\n",
    "    pool.join()\n",
    "\n",
    "    if returned:\n",
    "        return list(itertools.chain(*pooled))\n",
    "\n",
    "class UInt32(Encoding):\n",
    "    def encode(self, obj) -> bytes:\n",
    "        return obj.tobytes()\n",
    "\n",
    "    def decode(self, data: bytes):\n",
    "        return np.frombuffer(data, np.uint32)\n",
    "\n",
    "_encodings['uint32'] = UInt32\n",
    "\n",
    "columns = {\n",
    "    'input_ids': 'uint32',\n",
    "    'position_ids': 'uint32',\n",
    "    'attention_mask': 'uint32',\n",
    "    'audio': 'str',\n",
    "    'text': 'str'\n",
    "}\n",
    "hashes = 'sha1', 'xxh64'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b15fcdc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('mesolitica/Malaysian-Audio-Qwen2.5-7B-Instruct')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "625bfc66",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch_dtype = torch.bfloat16\n",
    "min_dtype = torch.finfo(torch_dtype).min\n",
    "sequence_length = 10240"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bd7144cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "190597it [00:02, 67644.17it/s] \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "190597"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "combine = []\n",
    "with open('combined-malaysian-sft-5k-sample.jsonl') as fopen:\n",
    "    for l in tqdm(fopen):\n",
    "        l = json.loads(l)\n",
    "        combine.append(l)\n",
    "\n",
    "len(combine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4b3716e9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Malaysian-Instructions/data/voice_assistant-00000-of-00002.parquet',\n",
       " 'Malaysian-Instructions/data/mixed_manglish-00000-of-00002.parquet',\n",
       " 'Malaysian-Instructions/data/voice_assistant-00001-of-00002.parquet',\n",
       " 'Malaysian-Instructions/data/longer_respond-00000-of-00001.parquet',\n",
       " 'Malaysian-Instructions/data/manglish-00000-of-00001.parquet',\n",
       " 'Malaysian-Instructions/data/mixed_manglish-00001-of-00002.parquet']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glob('Malaysian-Instructions/data/*.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "aa3a7c6b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 224611/224611 [00:04<00:00, 55575.57it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████| 224610/224610 [00:03<00:00, 65865.74it/s]\n"
     ]
    }
   ],
   "source": [
    "voice_assistant = glob('Malaysian-Instructions/data/voice*.parquet')\n",
    "\n",
    "for f in voice_assistant:\n",
    "    df = pd.read_parquet(f)\n",
    "    for i in tqdm(range(len(df))):\n",
    "        q = json.loads(df['question'].iloc[i])\n",
    "        chat = []\n",
    "        for q_ in q:\n",
    "            if 'content_ms' in q_:\n",
    "                q_['content'] = q_.pop('content_ms')\n",
    "            if q_['content'] is None:\n",
    "                break\n",
    "            chat.append(q_)\n",
    "        if len(chat):\n",
    "            combine.append(chat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3f0f3283",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Malaysian-Instructions/data/mixed_manglish-00000-of-00002.parquet\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 165430/165430 [00:06<00:00, 27235.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Malaysian-Instructions/data/manglish-00000-of-00001.parquet\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 118466/118466 [00:03<00:00, 30246.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Malaysian-Instructions/data/mixed_manglish-00001-of-00002.parquet\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 165430/165430 [00:05<00:00, 30205.07it/s]\n"
     ]
    }
   ],
   "source": [
    "manglish = glob('Malaysian-Instructions/data/*manglish*.parquet')\n",
    "\n",
    "for f in manglish:\n",
    "    print(f)\n",
    "    df = pd.read_parquet(f)\n",
    "    for i in tqdm(range(len(df))):\n",
    "        if df['question'].iloc[i] is None or df['answer'].iloc[i] is None:\n",
    "            continue\n",
    "        \n",
    "        if len(df['question'].iloc[i]) < 10 or len(df['answer'].iloc[i]) < 10:\n",
    "            continue\n",
    "        \n",
    "        chat = [\n",
    "            {'role': 'user', 'content': df['question'].iloc[i]},\n",
    "            {'role': 'assistant', 'content': df['answer'].iloc[i]}\n",
    "        ]\n",
    "        combine.append(chat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "84984d23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Malaysian-Instructions/data/longer_respond-00000-of-00001.parquet\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 3148/3148 [00:00<00:00, 29761.39it/s]\n"
     ]
    }
   ],
   "source": [
    "manglish = glob('Malaysian-Instructions/data/*longer*.parquet')\n",
    "\n",
    "for f in manglish:\n",
    "    print(f)\n",
    "    df = pd.read_parquet(f)\n",
    "    for i in tqdm(range(len(df))):\n",
    "        if df['question'].iloc[i] is None or df['answer'].iloc[i] is None:\n",
    "            continue\n",
    "        \n",
    "        if len(df['question'].iloc[i]) < 10 or len(df['answer'].iloc[i]) < 10:\n",
    "            continue\n",
    "        \n",
    "        chat = [\n",
    "            {'role': 'system', 'content': \"You are a highly knowledgeable and articulate chatbot. Your primary role is to provide very long, detailed, and precise explanations on any topic the user asks about. Structure your responses with clear logic, thorough reasoning, and factual depth. Avoid oversimplifying complex ideas. If appropriate, break your answers into sections with headings, examples, and breakdowns. Ensure every part of the user's question is fully addressed.\"},\n",
    "            {'role': 'user', 'content': df['question'].iloc[i]},\n",
    "            {'role': 'assistant', 'content': df['answer'].iloc[i]}\n",
    "        ]\n",
    "        combine.append(chat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "de9963e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1092292"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(combine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9d936d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "def collator(batch, batch_position_ids):\n",
    "    input_ids = []\n",
    "    position_ids = []\n",
    "    masks = []\n",
    "    for i in range(len(batch)):\n",
    "        l = len(batch[i])\n",
    "        input_ids.extend(batch[i])\n",
    "        position_ids.extend(batch_position_ids[i])\n",
    "        masks.append(l)\n",
    "    \n",
    "    return {\n",
    "        'input_ids': np.array(input_ids).astype(np.uint32),\n",
    "        'position_ids': np.array(position_ids).astype(np.uint32),\n",
    "        'attention_mask': np.array(masks).astype(np.uint32),\n",
    "        'audio': '',\n",
    "        'text': '',\n",
    "    }\n",
    "\n",
    "def slice_and_balance(nested_list, size):\n",
    "    first = []\n",
    "    balance = []\n",
    "    current_size = 0\n",
    "\n",
    "    for sublist in nested_list:\n",
    "        if current_size < size:\n",
    "            remaining_space = size - current_size\n",
    "            if len(sublist) <= remaining_space:\n",
    "                first.append(sublist)\n",
    "                current_size += len(sublist)\n",
    "            else:\n",
    "                first.append(sublist[:remaining_space])\n",
    "                balance.append(sublist[remaining_space:])\n",
    "                current_size = size\n",
    "        else:\n",
    "            balance.append(sublist)\n",
    "    \n",
    "    return first, balance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d9938993",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir tokenized-10k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1a29d671",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def loop(files, block_size = sequence_length):\n",
    "    rows, index = files\n",
    "    out_root = f'tokenized-10k/tokenized-{index}'\n",
    "    os.system(f'rm -rf {out_root}')\n",
    "    count = 0\n",
    "    temp = []\n",
    "    position_ids = []\n",
    "    last_block, last_position_block = None, None\n",
    "    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:\n",
    "        for row in tqdm(rows):\n",
    "            prompt = tokenizer.apply_chat_template(row, tokenize=False)\n",
    "            outputs = tokenizer(prompt, add_special_tokens = False)\n",
    "            temp.append(outputs['input_ids'])\n",
    "            position_ids.append(range(len(outputs['input_ids'])))\n",
    "            count += len(outputs['input_ids'])\n",
    "            while count >= block_size:\n",
    "                block, temp = slice_and_balance(temp, block_size)\n",
    "                block_position, position_ids = slice_and_balance(position_ids, block_size)\n",
    "                count = count - block_size\n",
    "                o = collator(block, block_position)\n",
    "                last_block = block\n",
    "                last_position_block = block_position\n",
    "                out.write(o)\n",
    "                \n",
    "        block, _ = slice_and_balance(last_block, block_size - count)\n",
    "        block_position, _ = slice_and_balance(last_position_block, block_size - count)\n",
    "\n",
    "        block.extend(temp)\n",
    "        block_position.extend(position_ids)\n",
    "\n",
    "        o = collator(block, block_position)\n",
    "        if len(o['input_ids']) == block_size:\n",
    "            out.write(o)\n",
    "            return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0187eccb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:44<00:00, 1116.93it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:04<00:00, 772.51it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:25<00:00, 1932.82it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:06<00:00, 752.67it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:45<00:00, 1103.63it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:47<00:00, 1058.59it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:04<00:00, 780.57it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:33<00:00, 1485.30it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:47<00:00, 1041.72it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:49<00:00, 1012.40it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:40<00:00, 498.86it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:51<00:00, 448.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:57<00:00, 865.57it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:59<00:00, 837.80it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:55<00:00, 894.64it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:58<00:00, 852.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:08<00:00, 729.68it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:48<00:00, 461.15it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:56<00:00, 879.08it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:04<00:00, 771.71it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 42292/42292 [01:04<00:00, 658.53it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:10<00:00, 705.68it/s]\n"
     ]
    }
   ],
   "source": [
    "from multiprocess import Pool\n",
    "\n",
    "chunks = chunks(combine, 50000)\n",
    "pool = Pool(30)\n",
    "pooled = pool.map(loop, chunks)\n",
    "pool.close()\n",
    "pool.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "fd5bcbf2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['tokenized-10k/tokenized-0',\n",
       " 'tokenized-10k/tokenized-1',\n",
       " 'tokenized-10k/tokenized-2',\n",
       " 'tokenized-10k/tokenized-3',\n",
       " 'tokenized-10k/tokenized-4',\n",
       " 'tokenized-10k/tokenized-5',\n",
       " 'tokenized-10k/tokenized-6',\n",
       " 'tokenized-10k/tokenized-7',\n",
       " 'tokenized-10k/tokenized-8',\n",
       " 'tokenized-10k/tokenized-9',\n",
       " 'tokenized-10k/tokenized-10',\n",
       " 'tokenized-10k/tokenized-11',\n",
       " 'tokenized-10k/tokenized-12',\n",
       " 'tokenized-10k/tokenized-13',\n",
       " 'tokenized-10k/tokenized-14',\n",
       " 'tokenized-10k/tokenized-15',\n",
       " 'tokenized-10k/tokenized-16',\n",
       " 'tokenized-10k/tokenized-17',\n",
       " 'tokenized-10k/tokenized-18',\n",
       " 'tokenized-10k/tokenized-19',\n",
       " 'tokenized-10k/tokenized-20',\n",
       " 'tokenized-10k/tokenized-21']"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "folders = sorted(glob('tokenized-10k/tokenized-*'), key = lambda x: int(x.split('-')[-1]))\n",
    "folders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6124767b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
